{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "readme_example.ipynb",
      "provenance": [],
      "authorship_tag": "ABX9TyPpME0Qdi/m3VZQ+jNj39dT",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Lednik7/CLIP-ONNX/blob/main/examples/readme_example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
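    {
      "cell_type": "markdown",
      "source": [
        "## Restart the Colab session after installation\n",
        "Restart the runtime once the install cell below has finished so the newly installed packages are picked up; reload the session if something still doesn't work."
      ],
      "metadata": {
        "id": "whlsBiJgR8le"
      }
    },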
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install git+https://github.com/Lednik7/CLIP-ONNX.git\n",
        "%pip install git+https://github.com/openai/CLIP.git\n",
        "%pip install onnxruntime-gpu"
      ],
      "metadata": {
        "id": "HnbpAkvuR73L"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "# Download the sample image; the URL is quoted so the shell does not treat `?` as a glob character.\n",
        "!wget -c -O CLIP.png 'https://github.com/openai/CLIP/blob/main/CLIP.png?raw=true'"
      ],
      "metadata": {
        "id": "tqy0zKM4R-7M"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi # expected to fail on a CPU-only runtime; this example uses the CPU provider"
      ],
      "metadata": {
        "id": "eKqETHL4YscZ",
        "outputId": "7ff0bc18-fb40-4296-ab05-b079043e46a1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import onnxruntime\n",
        "\n",
        "# Show the device onnxruntime reports as its priority device on this runtime.\n",
        "priority_device = onnxruntime.get_device()\n",
        "print(priority_device)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x8IN72OnSAIh",
        "outputId": "81d14047-91fa-4a5c-a1e3-f5b550556591"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CPU\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## CPU inference mode"
      ],
      "metadata": {
        "id": "U1Pr-YTtSEhs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import warnings\n",
        "\n",
        "# Silence every UserWarning for the rest of the notebook.\n",
        "warnings.simplefilter(\"ignore\", category=UserWarning)"
      ],
      "metadata": {
        "id": "gZTxanR26knr"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import clip\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "\n",
        "# ONNX export does not work with the model on CUDA, so load it on the CPU.\n",
        "model, preprocess = clip.load(\"ViT-B/32\", device=\"cpu\", jit=False)\n",
        "\n",
        "# Image input, batch dimension first: [1, 3, 224, 224]\n",
        "pil_image = Image.open(\"CLIP.png\")\n",
        "image = preprocess(pil_image).unsqueeze(0).cpu()\n",
        "image_onnx = image.detach().cpu().numpy().astype(np.float32)\n",
        "\n",
        "# Text input, batch dimension first: [3, 77]\n",
        "prompts = [\"a diagram\", \"a dog\", \"a cat\"]\n",
        "text = clip.tokenize(prompts).cpu()\n",
        "text_onnx = text.detach().cpu().numpy().astype(np.int32)"
      ],
      "metadata": {
        "id": "rPwc6A2SSGyl"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from clip_onnx import clip_onnx, attention\n",
        "\n",
        "# Optional patch, left disabled: replaces CLIP's attention method with the one shipped in clip_onnx.\n",
        "# clip.model.ResidualAttentionBlock.attention = attention\n",
        "\n",
        "visual_path = \"clip_visual.onnx\"\n",
        "textual_path = \"clip_textual.onnx\"\n",
        "\n",
        "onnx_model = clip_onnx(\n",
        "    model,\n",
        "    visual_path=visual_path,\n",
        "    textual_path=textual_path,\n",
        ")\n",
        "# Convert the visual and textual models, passing the example image and text tensors.\n",
        "onnx_model.convert2onnx(image, text, verbose=True)\n",
        "\n",
        "# Provider choices: ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']\n",
        "cpu_providers = [\"CPUExecutionProvider\"]  # cpu mode\n",
        "onnx_model.start_sessions(providers=cpu_providers)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oYM5FDSGSJBW",
        "outputId": "816705b1-3829-4424-c7c4-5426cf21cc18"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[CLIP ONNX] Start convert visual model\n",
            "[CLIP ONNX] Start check visual model\n",
            "[CLIP ONNX] Start convert textual model\n",
            "[CLIP ONNX] Start check textual model\n",
            "[CLIP ONNX] Models converts successfully\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Each encoder can also be called on its own.\n",
        "image_features = onnx_model.encode_image(image_onnx)\n",
        "text_features = onnx_model.encode_text(text_onnx)\n",
        "\n",
        "# Calling the model on both inputs yields the two logit tensors.\n",
        "logits_per_image, logits_per_text = onnx_model(image_onnx, text_onnx)\n",
        "\n",
        "# Softmax over the last axis, then convert to a NumPy array for printing.\n",
        "image_probs = logits_per_image.softmax(dim=-1)\n",
        "probs = image_probs.detach().cpu().numpy()\n",
        "\n",
        "print(\"Label probs:\", probs)  # prints: [[0.9927937  0.00421067 0.00299571]]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tYVuk72nSLw6",
        "outputId": "41608059-3732-4ea7-c619-66f803af4185"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Label probs: [[0.9927937  0.00421067 0.00299571]]\n"
          ]
        }
      ]
    }
  ]
}
